"""
Tests for MCP SEP-1686 task protocol support through mounted servers.

Verifies that tasks work seamlessly when calling tools/prompts/resources
on mounted child servers through a parent server.
"""

import asyncio

import pytest
from docket import Docket

from fastmcp import FastMCP, TaskConfig
from fastmcp.client import Client
from fastmcp.server.dependencies import CurrentDocket, CurrentFastMCP


@pytest.fixture(autouse=True)
def reset_docket_memory_server():
    """Isolate every test from Docket's shared in-memory server.

    For memory:// URLs Docket keeps a FakeServer on the class itself, so it
    survives from one test to the next. This fixture removes that class-level
    attribute both before the test runs and after it finishes.
    """

    def _drop_shared_server() -> None:
        # The attribute only exists once some test has created a memory docket.
        if hasattr(Docket, "_memory_server"):
            delattr(Docket, "_memory_server")

    _drop_shared_server()
    yield
    _drop_shared_server()


@pytest.fixture
def child_server():
    """Build the child server that the parent fixtures mount.

    It registers task-enabled tools, one tool with task support disabled,
    a task-enabled prompt, a static resource and a resource template.
    """
    server = FastMCP("child-server")

    @server.tool(task=True)
    async def multiply(a: int, b: int) -> int:
        """Multiply two numbers."""
        return a * b

    @server.tool(task=True)
    async def slow_child_tool(duration: float = 0.1) -> str:
        """A child tool that takes time to execute."""
        # Sleeping keeps the task observable in a non-terminal state.
        await asyncio.sleep(duration)
        return "child completed"

    @server.tool(task=False)
    async def sync_child_tool(message: str) -> str:
        """Child tool that only supports synchronous execution."""
        return f"child sync: {message}"

    @server.prompt(task=True)
    async def child_prompt(topic: str) -> str:
        """A child prompt that can execute as a task."""
        return f"Here is information about {topic} from the child server."

    @server.resource("child://data.txt", task=True)
    async def child_resource() -> str:
        """A child resource that can be read as a task."""
        return "Data from child server"

    @server.resource("child://item/{item_id}.json", task=True)
    async def child_item_resource(item_id: str) -> str:
        """A child resource template that can execute as a task."""
        return f'{{"itemId": "{item_id}", "source": "child"}}'

    return server


@pytest.fixture
def parent_server(child_server):
    """Parent server with its own task tool and ``child_server`` mounted as "child"."""
    server = FastMCP("parent-server")

    @server.tool(task=True)
    async def parent_tool(value: int) -> int:
        """A tool on the parent server."""
        return value * 10

    # The "child" prefix namespaces the mounted components (e.g. child_multiply).
    server.mount(child_server, prefix="child")
    return server


@pytest.fixture
def parent_server_no_prefix(child_server):
    """Parent server that mounts ``child_server`` without any name prefix."""
    server = FastMCP("parent-no-prefix")
    # With no prefix, mounted components keep the names the child gave them.
    server.mount(child_server)
    return server


class TestMountedToolTasks:
    """Background-task behaviour of tools that live on a mounted child server."""

    async def test_mounted_tool_task_returns_task_object(self, parent_server):
        """Calling a prefixed child tool with task=True yields a task handle."""
        async with Client(parent_server) as session:
            # The mount prefix turns "multiply" into "child_multiply".
            handle = await session.call_tool(
                "child_multiply", {"a": 6, "b": 7}, task=True
            )

            assert handle is not None
            assert hasattr(handle, "task_id")
            task_id = handle.task_id
            assert isinstance(task_id, str)
            assert len(task_id) > 0

    async def test_mounted_tool_task_executes_in_background(self, parent_server):
        """A task-enabled child tool is not executed inline."""
        async with Client(parent_server) as session:
            handle = await session.call_tool(
                "child_multiply", {"a": 3, "b": 4}, task=True
            )
            assert not handle.returned_immediately

    async def test_mounted_tool_task_returns_correct_result(
        self, parent_server: FastMCP
    ):
        """The awaited result of a mounted tool task carries the computed value."""
        async with Client(parent_server) as session:
            handle = await session.call_tool(
                "child_multiply", {"a": 8, "b": 9}, task=True
            )
            outcome = await handle.result()
            assert outcome.data == 72

    async def test_mounted_tool_task_status(self, parent_server):
        """A mounted tool task's status can be polled before and after it ends."""
        async with Client(parent_server) as session:
            handle = await session.call_tool(
                "child_slow_child_tool", {"duration": 0.5}, task=True
            )

            # The tool sleeps, so it may or may not have finished yet.
            early = await handle.status()
            assert early.status in ("working", "completed")

            await handle.wait(timeout=2.0)

            final = await handle.status()
            assert final.status == "completed"

    async def test_mounted_tool_task_cancellation(self, parent_server):
        """A long-running mounted tool task can be cancelled."""
        async with Client(parent_server) as session:
            handle = await session.call_tool(
                "child_slow_child_tool", {"duration": 10.0}, task=True
            )

            # Give the worker a moment to pick the task up before cancelling.
            await asyncio.sleep(0.1)
            await handle.cancel()

            after_cancel = await handle.status()
            assert after_cancel.status == "cancelled"

    async def test_graceful_degradation_sync_mounted_tool(self, parent_server):
        """A task=False child tool answers immediately with an error result."""
        async with Client(parent_server) as session:
            handle = await session.call_tool(
                "child_sync_child_tool", {"message": "hello"}, task=True
            )

            assert handle.returned_immediately

            outcome = await handle.result()
            assert outcome.is_error

    async def test_parent_and_mounted_tools_both_work(self, parent_server):
        """Tools on the parent and on the mounted child both run as tasks."""
        async with Client(parent_server) as session:
            own = await session.call_tool("parent_tool", {"value": 5}, task=True)
            mounted = await session.call_tool(
                "child_multiply", {"a": 2, "b": 3}, task=True
            )

            own_result = await own.result()
            mounted_result = await mounted.result()

            assert own_result.data == 50
            assert mounted_result.data == 6


class TestMountedToolTasksNoPrefix:
    """Background tasks for tools mounted without a name prefix."""

    async def test_mounted_tool_without_prefix_task_works(
        self, parent_server_no_prefix
    ):
        """An unprefixed mounted tool is reachable by its bare name as a task."""
        async with Client(parent_server_no_prefix) as session:
            # Without a prefix the child's "multiply" keeps its original name.
            handle = await session.call_tool("multiply", {"a": 5, "b": 6}, task=True)
            assert not handle.returned_immediately

            outcome = await handle.result()
            assert outcome.data == 30


class TestMountedPromptTasks:
    """Background-task behaviour of prompts exposed through a mount."""

    async def test_mounted_prompt_task_returns_task_object(self, parent_server):
        """Requesting a mounted prompt with task=True yields a task handle."""
        async with Client(parent_server) as session:
            # "child_prompt" gains the "child_" prefix -> "child_child_prompt".
            handle = await session.get_prompt(
                "child_child_prompt", {"topic": "FastMCP"}, task=True
            )

            assert handle is not None
            assert hasattr(handle, "task_id")
            assert isinstance(handle.task_id, str)

    async def test_mounted_prompt_task_executes_in_background(self, parent_server):
        """A task-enabled mounted prompt is rendered in the background."""
        async with Client(parent_server) as session:
            handle = await session.get_prompt(
                "child_child_prompt", {"topic": "testing"}, task=True
            )
            assert not handle.returned_immediately

    async def test_mounted_prompt_task_returns_correct_result(
        self, parent_server: FastMCP
    ):
        """The rendered prompt from a mounted task mentions topic and origin."""
        async with Client(parent_server) as session:
            handle = await session.get_prompt(
                "child_child_prompt", {"topic": "MCP protocol"}, task=True
            )

            rendered = await handle.result()
            first_text = rendered.messages[0].content.text
            assert "MCP protocol" in first_text
            assert "child server" in first_text


class TestMountedResourceTasks:
    """Background-task behaviour of resources exposed through a mount."""

    async def test_mounted_resource_task_returns_task_object(self, parent_server):
        """Reading a mounted resource with task=True yields a task handle."""
        async with Client(parent_server) as session:
            # The mount prefix is inserted after the scheme: child://child/...
            handle = await session.read_resource("child://child/data.txt", task=True)

            assert handle is not None
            assert hasattr(handle, "task_id")
            assert isinstance(handle.task_id, str)

    async def test_mounted_resource_task_executes_in_background(self, parent_server):
        """A task-enabled mounted resource is read in the background."""
        async with Client(parent_server) as session:
            handle = await session.read_resource("child://child/data.txt", task=True)
            assert not handle.returned_immediately

    async def test_mounted_resource_task_returns_correct_result(self, parent_server):
        """The contents of a mounted resource task match the child's data."""
        async with Client(parent_server) as session:
            handle = await session.read_resource("child://child/data.txt", task=True)

            contents = await handle.result()
            assert len(contents) > 0
            assert "Data from child server" in contents[0].text

    async def test_mounted_resource_template_task(self, parent_server):
        """A mounted resource template is expanded and executed as a task."""
        async with Client(parent_server) as session:
            handle = await session.read_resource(
                "child://child/item/99.json", task=True
            )
            assert not handle.returned_immediately

            contents = await handle.result()
            body = contents[0].text
            assert '"itemId": "99"' in body
            assert '"source": "child"' in body


class TestMountedTaskDependencies:
    """Test that dependencies work correctly in mounted task execution.

    Each test defines a child tool that takes an injected dependency. It
    records what it received, mounts the child under the "child" prefix and
    runs the tool as a background task through the parent.
    """

    async def test_mounted_task_receives_docket_dependency(self):
        """Mounted tool task receives CurrentDocket dependency."""
        child = FastMCP("dep-child")
        received_docket = []

        @child.tool(task=True)
        async def tool_with_docket(docket: CurrentDocket = CurrentDocket()) -> str:  # type: ignore[invalid-type-form]
            received_docket.append(docket)
            return f"docket available: {docket is not None}"

        parent = FastMCP("dep-parent")
        parent.mount(child, prefix="child")

        async with Client(parent) as client:
            task = await client.call_tool("child_tool_with_docket", {}, task=True)
            result = await task.result()

            # Compare the structured value exactly instead of substring-matching
            # the repr of the whole result object.
            assert result.data == "docket available: True"
            assert len(received_docket) == 1
            assert received_docket[0] is not None

    async def test_mounted_task_receives_server_dependency(self):
        """Mounted tool task receives CurrentFastMCP dependency."""
        child = FastMCP("server-dep-child")
        received_server = []

        @child.tool(task=True)
        async def tool_with_server(server: CurrentFastMCP = CurrentFastMCP()) -> str:  # type: ignore[invalid-type-form]
            received_server.append(server)
            return f"server name: {server.name}"

        parent = FastMCP("server-dep-parent")
        parent.mount(child, prefix="child")

        async with Client(parent) as client:
            task = await client.call_tool("child_tool_with_server", {}, task=True)
            result = await task.result()

            assert len(received_server) == 1
            injected = received_server[0]
            assert injected is not None
            # Whether the child (where the tool is defined) or the parent is
            # injected is a mounting implementation detail. Either is
            # acceptable, but it must be one of the two servers in play.
            assert injected.name in {child.name, parent.name}
            # The tool must echo the name of the server it actually received.
            assert result.data == f"server name: {injected.name}"


class TestMultipleMounts:
    """Tasks routed to several independently mounted child servers."""

    async def test_tasks_work_with_multiple_mounts(self):
        """Each mounted child resolves its own prefixed tool when run as a task."""
        adder = FastMCP("child1")
        subtractor = FastMCP("child2")

        @adder.tool(task=True)
        async def add(a: int, b: int) -> int:
            return a + b

        @subtractor.tool(task=True)
        async def subtract(a: int, b: int) -> int:
            return a - b

        parent = FastMCP("multi-parent")
        # Distinct prefixes keep the two children's tool names apart.
        parent.mount(adder, prefix="math1")
        parent.mount(subtractor, prefix="math2")

        async with Client(parent) as session:
            sum_task = await session.call_tool(
                "math1_add", {"a": 10, "b": 5}, task=True
            )
            diff_task = await session.call_tool(
                "math2_subtract", {"a": 10, "b": 5}, task=True
            )

            sum_result = await sum_task.result()
            diff_result = await diff_task.result()

            assert sum_result.data == 15
            assert diff_result.data == 5


class TestMountedTaskList:
    """Listing tasks when some were created via mounted components."""

    async def test_list_tasks_includes_mounted_tasks(self, parent_server):
        """The task list reports tasks from both parent and mounted tools."""
        async with Client(parent_server) as session:
            own = await session.call_tool("parent_tool", {"value": 1}, task=True)
            mounted = await session.call_tool(
                "child_multiply", {"a": 2, "b": 2}, task=True
            )

            for handle in (own, mounted):
                await handle.wait(timeout=2.0)

            # The listing is a dict whose "tasks" entry holds the task records.
            listing = await session.list_tasks()
            known_ids = {record["taskId"] for record in listing["tasks"]}

            assert own.task_id in known_ids
            assert mounted.task_id in known_ids


class TestMountedTaskConfigModes:
    """Test TaskConfig mode enforcement for mounted tools.

    The child server disables tasks at the server level (``tasks=False``).
    Each tool then opts into a mode explicitly, so these tests exercise the
    per-tool configuration as seen through the parent's mount.
    """

    @pytest.fixture
    def child_with_modes(self):
        """Create a child server with tools in all three TaskConfig modes."""
        mcp = FastMCP("child-modes", tasks=False)

        @mcp.tool(task=TaskConfig(mode="optional"))
        async def optional_tool() -> str:
            """Tool that supports both sync and task execution."""
            return "optional result"

        @mcp.tool(task=TaskConfig(mode="required"))
        async def required_tool() -> str:
            """Tool that requires task execution."""
            return "required result"

        @mcp.tool(task=TaskConfig(mode="forbidden"))
        async def forbidden_tool() -> str:
            """Tool that forbids task execution."""
            return "forbidden result"

        return mcp

    @pytest.fixture
    def parent_with_modes(self, child_with_modes):
        """Create a parent server with the child mounted."""
        parent = FastMCP("parent-modes")
        parent.mount(child_with_modes, prefix="child")
        return parent

    async def test_optional_mode_sync_through_mount(self, parent_with_modes):
        """Optional mode tool works without task through mount."""
        async with Client(parent_with_modes) as client:
            result = await client.call_tool("child_optional_tool", {})
            assert "optional result" in str(result)

    async def test_optional_mode_task_through_mount(self, parent_with_modes):
        """Optional mode tool works with task through mount."""
        async with Client(parent_with_modes) as client:
            task = await client.call_tool("child_optional_tool", {}, task=True)
            assert task is not None
            result = await task.result()
            assert result.data == "optional result"

    async def test_required_mode_with_task_through_mount(self, parent_with_modes):
        """Required mode tool succeeds with task through mount."""
        async with Client(parent_with_modes) as client:
            task = await client.call_tool("child_required_tool", {}, task=True)
            assert task is not None
            result = await task.result()
            assert result.data == "required result"

    async def test_required_mode_without_task_through_mount(self, parent_with_modes):
        """Required mode tool errors without task through mount."""
        from fastmcp.exceptions import ToolError

        async with Client(parent_with_modes) as client:
            with pytest.raises(ToolError) as exc_info:
                await client.call_tool("child_required_tool", {})

            assert "requires task-augmented execution" in str(exc_info.value)

    async def test_forbidden_mode_sync_through_mount(self, parent_with_modes):
        """Forbidden mode tool works without task through mount."""
        async with Client(parent_with_modes) as client:
            result = await client.call_tool("child_forbidden_tool", {})
            assert "forbidden result" in str(result)

    async def test_forbidden_mode_with_task_through_mount(self, parent_with_modes):
        """Forbidden mode tool degrades gracefully with task through mount.

        The request is answered immediately instead of being queued, and the
        result is flagged as an error. This is the same outcome expected for a
        ``task=False`` tool called with ``task=True`` through a mount.
        """
        async with Client(parent_with_modes) as client:
            task = await client.call_tool("child_forbidden_tool", {}, task=True)

            # Should return immediately (graceful degradation)
            assert task.returned_immediately

            result = await task.result()
            # Asserting only "not None" could never fail meaningfully; pin the
            # error flag that forbidden-mode degradation is expected to set.
            assert result is not None
            assert result.is_error
